import torch
from torch import nn

# Print a 3x3 lower-triangular boolean mask: entry (i, j) is True when
# j <= i (on or below the main diagonal) and False strictly above it.
print(
    torch.ones(3, 3, dtype=torch.bool).tril(diagonal=0)
)

